Introduction
This project focuses on analyzing the Stroke Prediction Dataset to develop a machine learning model that can predict a patient's likelihood of suffering a stroke. According to the World Health Organization, strokes are the 2nd leading cause of death globally, responsible for approximately 11% of total deaths. This makes early prediction and prevention of strokes critically important.
The dataset contains various attributes about patients including:
- Demographics
  - Age
  - Gender
- Health factors
  - Hypertension
  - Heart disease
  - Average glucose levels
  - BMI
- Lifestyle
  - Smoking status
  - Work type
- Location
  - Urban vs rural residence
The target variable is whether the patient suffered a stroke.
As data analysts working with The Johns Hopkins Hospital, we aim to explore this data thoroughly and uncover patterns and insights that can inform the development of a robust predictive model. This will enable doctors to identify high-risk patients and advise them and their families on precautionary measures.
The analysis will progress through the following key steps:
Exploratory Data Analysis - Examining the distributions, ranges, and relationships between the features and target variable through statistical summaries and visualizations. Checking data quality.
Statistical Inference - Formulating and testing hypotheses about stroke risk factors and quantifying uncertainty through confidence intervals.
Machine Learning Modeling - Applying a range of classification algorithms including logistic regression, decision trees, random forests and more to predict stroke likelihood. Tuning hyperparameters and building ensembles to optimize predictive performance.
Model Deployment - Selecting the top performing model and deploying it to enable real-time stroke risk prediction, potentially as a web app or containerized microservice.
Throughout this notebook, detailed commentary will be provided on the analytical approach, key findings, model results and ideas for further enhancement. The goal is to demonstrate a thoughtful, thorough analysis while documenting reproducible steps from data intake through model deployment.
By predicting stroke risk, this project aims to arm healthcare providers with a powerful tool to identify and engage high-risk patients, ultimately reducing the devastating impact of this condition. Let's begin the analysis to see what insights the data holds.
import warnings
warnings.filterwarnings("ignore")
import joblib
import numpy as np
import pandas as pd
import pingouin as pg
from IPython.display import Image
from catboost import CatBoostClassifier
from scipy import stats
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import (
    StratifiedKFold,
    train_test_split,
)
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler
from skopt import BayesSearchCV
from skopt.space import Categorical, Integer, Real
import xgboost as xgb
import lightgbm as lgb
import shap
from stroke_risk_predictor.utils.analysis_tools import (
    plot_combined_histograms,
    plot_combined_bar_charts,
    plot_combined_boxplots,
    plot_correlation_matrix,
    flag_anomalies,
    evaluate_model,
    plot_model_performance,
    plot_combined_confusion_matrices,
    extract_feature_importances,
    plot_feature_importances,
    detect_anomalies_iqr,
    calculate_cramers_v,
    CustomVotingClassifier,
    CustomLogisticRegressionWrapper,
)
stroke_df = pd.read_csv("../data/stroke_dataset.csv")
stroke_df.head()
| | id | gender | age | hypertension | heart_disease | ever_married | work_type | Residence_type | avg_glucose_level | bmi | smoking_status | stroke |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 9046 | Male | 67.0 | 0 | 1 | Yes | Private | Urban | 228.69 | 36.6 | formerly smoked | 1 |
| 1 | 51676 | Female | 61.0 | 0 | 0 | Yes | Self-employed | Rural | 202.21 | NaN | never smoked | 1 |
| 2 | 31112 | Male | 80.0 | 0 | 1 | Yes | Private | Rural | 105.92 | 32.5 | never smoked | 1 |
| 3 | 60182 | Female | 49.0 | 0 | 0 | Yes | Private | Urban | 171.23 | 34.4 | smokes | 1 |
| 4 | 1665 | Female | 79.0 | 1 | 0 | Yes | Self-employed | Rural | 174.12 | 24.0 | never smoked | 1 |
duplicates = stroke_df.duplicated().sum()
print(f"Number of duplicate rows: {duplicates}")
if duplicates > 0:
    stroke_df = stroke_df.drop_duplicates()
    print("Duplicates removed.")
Number of duplicate rows: 0
There are no duplicate rows in the dataset, so we can move forward.
stroke_df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5110 entries, 0 to 5109
Data columns (total 12 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 id 5110 non-null int64
1 gender 5110 non-null object
2 age 5110 non-null float64
3 hypertension 5110 non-null int64
4 heart_disease 5110 non-null int64
5 ever_married 5110 non-null object
6 work_type 5110 non-null object
7 Residence_type 5110 non-null object
8 avg_glucose_level 5110 non-null float64
9 bmi 4909 non-null float64
10 smoking_status 5110 non-null object
11 stroke 5110 non-null int64
dtypes: float64(3), int64(4), object(5)
memory usage: 479.2+ KB
The dataset contains a mix of integer, float, and object data types, which are appropriate for the corresponding variables. Next, we check for missing values.
print(stroke_df.isnull().sum())
id 0
gender 0
age 0
hypertension 0
heart_disease 0
ever_married 0
work_type 0
Residence_type 0
avg_glucose_level 0
bmi 201
smoking_status 0
stroke 0
dtype: int64
This dataset contains 5110 entries and 12 columns related to potential stroke risk factors.
Quick Facts:
- Features: id, gender, age, hypertension, heart_disease, ever_married, work_type, Residence_type, avg_glucose_level, bmi, smoking_status
- Target Variable: stroke (binary: 0 or 1)
- Data Types: Mixture of numerical (int64, float64) and categorical (object) features
- Missing Values: 201 in 'bmi' column (3.93% of dataset)
Key Observations:
- Diverse risk factors: demographic, health conditions, lifestyle, and biometric measurements
- Binary target variable (stroke occurrence)
- Potential for class imbalance in target variable (to be checked)
Initial Steps:
- Clean data: rename columns, handle missing values.
- Explore feature distributions and relationships with target
- Conduct statistical tests to validate risk factor relationships
stroke_df = stroke_df.rename(columns={"Residence_type": "residence_type"})
In this case, we will handle missing values in the bmi column by dropping the rows with missing values, as they account for only 3.93% of the dataset.
stroke_df = stroke_df.dropna(subset=["bmi"])
stroke_df.head()
| | id | gender | age | hypertension | heart_disease | ever_married | work_type | residence_type | avg_glucose_level | bmi | smoking_status | stroke |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 9046 | Male | 67.0 | 0 | 1 | Yes | Private | Urban | 228.69 | 36.6 | formerly smoked | 1 |
| 2 | 31112 | Male | 80.0 | 0 | 1 | Yes | Private | Rural | 105.92 | 32.5 | never smoked | 1 |
| 3 | 60182 | Female | 49.0 | 0 | 0 | Yes | Private | Urban | 171.23 | 34.4 | smokes | 1 |
| 4 | 1665 | Female | 79.0 | 1 | 0 | Yes | Self-employed | Rural | 174.12 | 24.0 | never smoked | 1 |
| 5 | 56669 | Male | 81.0 | 0 | 0 | Yes | Private | Urban | 186.21 | 29.0 | formerly smoked | 1 |
With missing values in bmi handled and features renamed, let's examine the dataset structure.
print(stroke_df.describe().T)
count mean std min 25% \
id 4909.0 37064.313506 20995.098457 77.00 18605.00
age 4909.0 42.865374 22.555115 0.08 25.00
hypertension 4909.0 0.091872 0.288875 0.00 0.00
heart_disease 4909.0 0.049501 0.216934 0.00 0.00
avg_glucose_level 4909.0 105.305150 44.424341 55.12 77.07
bmi 4909.0 28.893237 7.854067 10.30 23.50
stroke 4909.0 0.042575 0.201917 0.00 0.00
50% 75% max
id 37608.00 55220.00 72940.00
age 44.00 60.00 82.00
hypertension 0.00 0.00 1.00
heart_disease 0.00 0.00 1.00
avg_glucose_level 91.68 113.57 271.74
bmi 28.10 33.10 97.60
stroke 0.00 0.00 1.00
Current Observations:
- Numerical Features:
  - age: Average 42.87 years, range 0.08 to 82.
  - avg_glucose_level: Average 105.31, large standard deviation (44.42).
  - bmi: Average 28.89, range 10.30 to 97.60.
- Binary Features:
  - hypertension, heart_disease, and stroke are binary (0 or 1).
  - Low prevalence of hypertension and heart disease.
  - stroke (target variable) has low prevalence (about 4%), indicating class imbalance.
Next Step: Analyze Distributions of All Variables
Prior to encoding, it's crucial to comprehensively analyze the distributions of both numerical and categorical variables. This analysis will provide valuable insights into our dataset's characteristics and guide our encoding and preprocessing strategies.
We should proceed as follows:
Numerical Variables:
- Create histograms and box plots for age, avg_glucose_level, and bmi.
- Look for outliers, skewness, and any unusual patterns.
- Consider if any transformations (e.g., log transformation) might be beneficial.
Binary Variables:
- Create bar plots for hypertension, heart_disease, and stroke.
- Quantify the exact prevalence of each condition.
- For stroke, our target variable, consider strategies to handle class imbalance.
Categorical Variables:
- Create bar plots for gender, ever_married, work_type, residence_type, and smoking_status.
- Examine the distribution of categories within each variable.
- Look for any categories with very low frequency, which might need special handling.
numerical_features = ["age", "avg_glucose_level", "bmi"]
plot_combined_histograms(
stroke_df,
numerical_features,
nbins=30,
save_path="../images/numerical_distributions.png",
)
Image(filename="../images/numerical_distributions.png")
The histograms reveal the following about age, avg_glucose_level, and bmi:
Age: Shows a relatively uniform distribution across most age ranges, with slight increases in frequency for middle-aged adults (around 45-65). There's a noticeable drop-off for very young (<20) and very old (>80) ages. This uniform distribution is unusual for demographic data and may warrant further investigation into the data collection process or potential sampling biases.
Average Glucose Level: Strongly right-skewed, with a peak around 90-100 mg/dL and a long tail extending to higher values. There's a secondary smaller peak around 200-250 mg/dL, which could indicate a subgroup with diabetes or pre-diabetes.
BMI: Approximately normally distributed, centered around 25-30, with a slight right skew. There are notable outliers at very high BMI values (>60) that warrant further investigation.
Next up, we can move on to the categorical features.
categorical_features = [
"gender",
"hypertension",
"heart_disease",
"ever_married",
"work_type",
"residence_type",
"smoking_status",
"stroke",
]
categorical_features_set1 = [
"gender",
"hypertension",
"heart_disease",
"smoking_status",
]
categorical_features_set2 = ["ever_married", "work_type", "residence_type", "stroke"]
plot_combined_bar_charts(
stroke_df,
categorical_features_set1,
max_features_per_plot=4,
save_path="../images/categorical_distributions_set1",
)
Image(filename="../images/categorical_distributions_set1_chunk_1.png")
The bar plots reveal the following about gender, hypertension, heart_disease, and smoking_status:
Gender:
- The dataset contains more females than males.
- There's a very small number of "Other" gender entries, which may need special handling in the analysis.
Hypertension:
- Highly imbalanced distribution.
- The vast majority of patients do not have hypertension (value 0).
- This imbalance will need to be addressed in the modeling phase to prevent bias.
Heart Disease:
- Similar to hypertension, there's a significant imbalance.
- Most patients in the dataset do not have heart disease (value 0).
- This imbalance also requires attention during model development.
Smoking Status:
- "Never smoked" is the most common category.
- There's a significant number of "Unknown" entries, which may require special handling.
- "Formerly smoked" and "smokes" categories have lower, but similar frequencies.
- The high number of "Unknown" entries could impact the analysis and may need imputation or special treatment.
plot_combined_bar_charts(
stroke_df,
categorical_features_set2,
max_features_per_plot=4,
save_path="../images/categorical_distributions_set2",
)
Image(filename="../images/categorical_distributions_set2_chunk_1.png")
The bar plots reveal the following about ever_married, work_type, residence_type, and stroke:
Ever Married:
- More married individuals ("Yes") than unmarried ("No") in the dataset.
- This could be correlated with age and might provide insights when analyzed together.
Work Type:
- "Private" is the most common category, followed by "Self-employed".
- "Govt_job" and "children" categories have similar, lower frequencies.
- There are very few "Never_worked" entries.
- The "children" category might overlap with the younger age group, warranting further investigation.
Residence Type:
- Nearly equal distribution between Urban and Rural residences.
- This balance is good for analyzing the impact of residence type on stroke risk without bias from uneven representation.
Stroke:
- The vast majority of individuals (about 4,700 of 4,909) are in the "0" category, which represents no stroke.
- A much smaller number (about 210, roughly 4.3%) are in the "1" category, representing those who have had a stroke.
- This imbalance in the target variable will need to be addressed during model development.
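One common way to address this imbalance is to weight classes inversely to their frequency. The sketch below shows the arithmetic behind the "balanced" heuristic, weight = n_samples / (n_classes * class_count). The counts are derived from the summary statistics shown earlier (4,909 rows with a stroke mean of about 0.0426), and the function name is a hypothetical helper rather than part of the project's utilities.

```python
def balanced_class_weights(class_counts):
    """Weight each class by n_samples / (n_classes * class_count)."""
    n_samples = sum(class_counts.values())
    n_classes = len(class_counts)
    return {
        label: n_samples / (n_classes * count)
        for label, count in class_counts.items()
    }


# Assumed counts, implied by describe(): 4,909 rows and about 4.26% strokes.
counts = {0: 4700, 1: 209}
weights = balanced_class_weights(counts)

# Ratio of negatives to positives, the quantity typically passed as a
# positive-class weight in gradient boosting libraries.
neg_pos_ratio = counts[0] / counts[1]

print(weights)
print(round(neg_pos_ratio, 2))
```

With these counts, a stroke case receives a weight roughly 22 times larger than a non-stroke case. In practice the counts should come from the training split only.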
Next, we can move on to checking the outliers in the numerical features.
plot_combined_boxplots(
stroke_df, numerical_features, save_path="../images/numerical_boxplots.png"
)
Image(filename="../images/numerical_boxplots.png")
The box plots show outliers in avg_glucose_level and bmi, so we investigate them further.
anomalies = detect_anomalies_iqr(stroke_df, numerical_features)
print("Detected anomalies:")
print(anomalies)
No anomalies detected in feature 'age'.
Anomalies detected in feature 'avg_glucose_level':
id gender age hypertension heart_disease ever_married \
0 9046 Male 67.0 0 1 Yes
3 60182 Female 49.0 0 0 Yes
4 1665 Female 79.0 1 0 Yes
5 56669 Male 81.0 0 0 Yes
14 5317 Female 79.0 0 1 Yes
... ... ... ... ... ... ...
5061 38009 Male 41.0 0 0 Yes
5062 11184 Female 82.0 0 0 Yes
5063 68967 Male 39.0 0 0 Yes
5064 66684 Male 70.0 0 0 Yes
5076 39935 Female 34.0 0 0 Yes
work_type residence_type avg_glucose_level bmi smoking_status \
0 Private Urban 228.69 36.6 formerly smoked
3 Private Urban 171.23 34.4 smokes
4 Self-employed Rural 174.12 24.0 never smoked
5 Private Urban 186.21 29.0 formerly smoked
14 Private Urban 214.09 28.2 never smoked
... ... ... ... ... ...
5061 Private Urban 223.78 32.3 never smoked
5062 Self-employed Rural 211.58 36.9 never smoked
5063 Private Urban 179.38 27.7 Unknown
5064 Self-employed Rural 193.88 24.3 Unknown
5076 Private Rural 174.37 23.0 never smoked
stroke
0 1
3 1
4 1
5 1
14 1
... ...
5061 0
5062 0
5063 0
5064 0
5076 0
[567 rows x 12 columns]
Anomalies detected in feature 'bmi':
id gender age hypertension heart_disease ever_married \
21 13861 Female 52.0 1 0 Yes
113 41069 Female 45.0 0 0 Yes
254 32257 Female 47.0 0 0 Yes
258 28674 Female 74.0 1 0 Yes
270 72911 Female 57.0 1 0 Yes
... ... ... ... ... ... ...
4858 1696 Female 43.0 0 0 Yes
4906 72696 Female 53.0 0 0 Yes
4952 16245 Male 51.0 1 0 Yes
5009 40732 Female 50.0 0 0 Yes
5057 38349 Female 49.0 0 0 Yes
work_type residence_type avg_glucose_level bmi smoking_status \
21 Self-employed Urban 233.29 48.9 never smoked
113 Private Rural 224.10 56.6 never smoked
254 Private Urban 210.95 50.1 Unknown
258 Self-employed Urban 205.84 54.6 never smoked
270 Private Rural 129.54 60.9 smokes
... ... ... ... ... ...
4858 Private Urban 100.88 47.6 smokes
4906 Private Urban 70.51 54.1 never smoked
4952 Self-employed Rural 211.83 56.6 never smoked
5009 Self-employed Rural 126.85 49.5 formerly smoked
5057 Govt_job Urban 69.92 47.6 never smoked
stroke
21 1
113 1
254 0
258 0
270 0
... ...
4858 0
4906 0
4952 0
5009 0
5057 0
[110 rows x 12 columns]
Detected anomalies:
age avg_glucose_level bmi
0 67.0 228.69 36.6
1 49.0 171.23 34.4
2 79.0 174.12 24.0
3 81.0 186.21 29.0
4 79.0 214.09 28.2
.. ... ... ...
644 30.0 84.92 47.8
645 43.0 100.88 47.6
646 53.0 70.51 54.1
647 50.0 126.85 49.5
648 49.0 69.92 47.6
[649 rows x 3 columns]
Our analysis revealed the presence of outliers in the dataset. After careful consideration, we have decided to retain these outliers for the following reasons:
1. Domain-Specific Considerations
- Medical Significance: In healthcare datasets, extreme values often represent clinically significant cases.
- Preserving Information: Removing outliers without domain expertise risks losing valuable insights.
2. Dataset Characteristics
- Class Imbalance: The dataset exhibits an imbalanced distribution, with rare occurrences of the target variable (stroke).
- Rare Case Representation: Eliminating outliers could further reduce the already limited representation of these critical cases.
3. Model Robustness
- Diverse Training Data: Including outliers helps develop models that are more robust and generalize better across a wide range of scenarios.
- Avoiding Overfitting: Retaining outliers can prevent models from becoming overly sensitive to a narrow range of data points.
4. Proposed Approach
To balance the need for data integrity with the potential impact of outliers, we propose the following strategy:
- Outlier Flagging: Introduce a new binary feature called has_anomalies to identify potential outliers.
- Flexible Handling: This approach allows for targeted treatment of outliers in subsequent analyses and modeling stages.
5. Benefits of This Strategy
- Data Integrity: Preserves the original dataset without loss of potentially crucial information.
- Analytical Flexibility: Enables customized handling of outliers based on specific requirements of each analysis or modeling task.
- Transparency: Clearly identifies potential anomalies for further investigation or specialized treatment.
By adopting this nuanced approach to outlier management, we aim to maintain the dataset's integrity while providing the flexibility needed for robust analysis and modeling.
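The flag_anomalies helper lives in the project's utilities and its body is not shown here. As a rough illustration of what such a flag could look like, here is a minimal sketch assuming the standard 1.5 × IQR rule. That assumption is at least consistent with the output above: the avg_glucose_level quartiles (77.07 and 113.57) give an upper fence of 113.57 + 1.5 × 36.50 = 168.32, and the smallest flagged glucose values shown are just above it. The function name and toy data are hypothetical.

```python
import pandas as pd


def flag_iqr_outliers(df, features, k=1.5):
    """Return a boolean Series: True where any feature falls outside its IQR fences."""
    flags = pd.Series(False, index=df.index)
    for feature in features:
        q1 = df[feature].quantile(0.25)
        q3 = df[feature].quantile(0.75)
        iqr = q3 - q1
        lower, upper = q1 - k * iqr, q3 + k * iqr
        # A row is flagged as soon as one feature is out of range.
        flags |= (df[feature] < lower) | (df[feature] > upper)
    return flags


# Toy data (not from the stroke dataset): one extreme BMI value.
toy = pd.DataFrame(
    {
        "bmi": [22.0, 24.0, 25.0, 27.0, 29.0, 97.6],
        "age": [30.0, 40.0, 45.0, 50.0, 60.0, 65.0],
    }
)
toy["has_anomalies"] = flag_iqr_outliers(toy, ["bmi", "age"])
print(toy)
```

Because the fences are computed from the data they are applied to, a stricter pipeline would estimate them on the training split only and reuse them for validation and test data.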
stroke_df["has_anomalies"] = flag_anomalies(stroke_df, numerical_features)
stroke_df["has_anomalies"].value_counts()
has_anomalies
False 4260
True 649
Name: count, dtype: int64
stroke_df.head()
| | id | gender | age | hypertension | heart_disease | ever_married | work_type | residence_type | avg_glucose_level | bmi | smoking_status | stroke | has_anomalies |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 9046 | Male | 67.0 | 0 | 1 | Yes | Private | Urban | 228.69 | 36.6 | formerly smoked | 1 | True |
| 2 | 31112 | Male | 80.0 | 0 | 1 | Yes | Private | Rural | 105.92 | 32.5 | never smoked | 1 | False |
| 3 | 60182 | Female | 49.0 | 0 | 0 | Yes | Private | Urban | 171.23 | 34.4 | smokes | 1 | True |
| 4 | 1665 | Female | 79.0 | 1 | 0 | Yes | Self-employed | Rural | 174.12 | 24.0 | never smoked | 1 | True |
| 5 | 56669 | Male | 81.0 | 0 | 0 | Yes | Private | Urban | 186.21 | 29.0 | formerly smoked | 1 | True |
plot_correlation_matrix(
stroke_df,
numerical_features + ["stroke"],
save_path="../images/correlation_matrix.png",
)
Image(filename="../images/correlation_matrix.png")
Correlation Matrix Analysis
The correlation matrix visually represents the pairwise correlations between key numerical variables in our dataset:
- Age
- Average glucose level
- BMI (Body Mass Index)
- Stroke (target variable)
Key Interpretations
| Relationship | Correlation | Interpretation |
|---|---|---|
| Age and Stroke | 0.23 | Weak positive correlation, but the strongest with the target; suggests elevated stroke risk with age |
| Average Glucose Level and Stroke | 0.14 | Weak positive correlation; higher blood sugar might be associated with higher stroke risk |
| BMI and Stroke | 0.04 | Very weak positive correlation; little linear association between BMI and stroke risk |
| Age and BMI | 0.33 | Moderate positive correlation; older individuals tend to have higher BMI |
| Age and Average Glucose Level | 0.24 | Weak positive correlation; glucose levels tend to increase slightly with age |
| BMI and Average Glucose Level | 0.18 | Weak positive correlation; higher BMI slightly associated with higher glucose levels |
Interpretation Guidelines
- Strong correlation: |r| > 0.5
- Moderate correlation: 0.3 < |r| ≤ 0.5
- Weak correlation: 0.1 < |r| ≤ 0.3
- Very weak correlation: |r| ≤ 0.1
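To apply these thresholds consistently, the guidelines can be written as a small lookup. This is only a sketch of the rule of thumb stated above; the function name is hypothetical and the cut-offs are the ones listed, not a universal standard.

```python
def correlation_strength(r):
    """Label a correlation coefficient using the thresholds listed above."""
    magnitude = abs(r)
    if magnitude > 0.5:
        return "strong"
    if magnitude > 0.3:
        return "moderate"
    if magnitude > 0.1:
        return "weak"
    return "very weak"


# Coefficients reported in the correlation matrix.
reported = {
    "age vs stroke": 0.23,
    "avg_glucose_level vs stroke": 0.14,
    "bmi vs stroke": 0.04,
    "age vs bmi": 0.33,
}
labels = {pair: correlation_strength(r) for pair, r in reported.items()}
print(labels)
```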
Additional Considerations
- Correlations do not imply causation.
- Some relationships may be non-linear and require further investigation.
- Confounding factors may influence observed correlations.
We are going to keep all of the features because:
- There isn't strong multicollinearity between the predictors (highest correlation is 0.33).
- All features show a positive correlation with the target variable, although BMI's is very weak, so each may still contribute some predictive power.
- Removing features based solely on correlation might lead to loss of important information.
Therefore, before moving further with modeling, we can proceed with encoding the categorical features.
Based on our distribution analysis, we identified several categorical features that need encoding. Our encoding strategy will be as follows:
For binary categorical features (those with 2 unique values), we will use label encoding. This is appropriate because there's no implicit ordering, and it's a simple 0/1 representation.
For categorical features with more than 2 unique values, we will use one-hot encoding. This avoids introducing an arbitrary ordinal relationship between categories.
binary_features = ["ever_married", "residence_type"]
label_encoder = LabelEncoder()
for feature in binary_features:
    stroke_df[feature] = label_encoder.fit_transform(stroke_df[feature])
stroke_df["has_anomalies"] = stroke_df["has_anomalies"].astype(int)
Next up we use one hot encoding for categorical features with more than 2 unique values.
onehot_features = ["gender", "work_type", "smoking_status"]
onehot_encoder = OneHotEncoder(sparse_output=False)
onehot_encoded = onehot_encoder.fit_transform(stroke_df[onehot_features])
onehot_columns = onehot_encoder.get_feature_names_out(onehot_features)
column_mapping = {}
for feature, categories in zip(onehot_features, onehot_encoder.categories_):
    for category in categories:
        old_name = f"{feature}_{category}"
        new_name = f"{feature}_{category.lower().replace(' ', '_')}"
        column_mapping[old_name] = new_name
onehot_columns = [column_mapping.get(col, col) for col in onehot_columns]
stroke_df = stroke_df.drop(columns=onehot_features)
stroke_df[onehot_columns] = onehot_encoded
print(stroke_df.head())
id age hypertension heart_disease ever_married residence_type \
0 9046 67.0 0 1 1 1
2 31112 80.0 0 1 1 0
3 60182 49.0 0 0 1 1
4 1665 79.0 1 0 1 0
5 56669 81.0 0 0 1 1
avg_glucose_level bmi stroke has_anomalies ... gender_other \
0 228.69 36.6 1 1 ... 0.0
2 105.92 32.5 1 0 ... 0.0
3 171.23 34.4 1 1 ... 0.0
4 174.12 24.0 1 1 ... 0.0
5 186.21 29.0 1 1 ... 0.0
work_type_govt_job work_type_never_worked work_type_private \
0 0.0 0.0 1.0
2 0.0 0.0 1.0
3 0.0 0.0 1.0
4 0.0 0.0 0.0
5 0.0 0.0 1.0
work_type_self-employed work_type_children smoking_status_unknown \
0 0.0 0.0 0.0
2 0.0 0.0 0.0
3 0.0 0.0 0.0
4 1.0 0.0 0.0
5 0.0 0.0 0.0
smoking_status_formerly_smoked smoking_status_never_smoked \
0 1.0 0.0
2 0.0 1.0
3 0.0 0.0
4 0.0 1.0
5 1.0 0.0
smoking_status_smokes
0 0.0
2 0.0
3 1.0
4 0.0
5 0.0
[5 rows x 22 columns]
Next, we can move to our statistical inference.
Statistical Inference
Highlights:
- Investigate the relationships between age, glucose level, BMI, hypertension, heart disease, and stroke occurrence
- Conduct t-tests for continuous variables and chi-square tests for categorical variables
- Report p-values, effect sizes, and confidence intervals
- Check assumptions and apply multiple comparison adjustments if needed
Target Population and Sample: The target population is adults at risk of stroke. The sample consists of 4,909 individuals with diverse demographic and health characteristics.
Significance Level: α = 0.05
Hypotheses and Tests:
Age and Stroke Risk
- H0: No difference in mean age between stroke and non-stroke groups
- H1: Significant difference in mean age between stroke and non-stroke groups
- Test: Independent samples t-test (two-tailed)
- Effect size: Cohen's d
Glucose Level and Stroke Risk
- H0: No difference in mean glucose levels between stroke and non-stroke groups
- H1: Significant difference in mean glucose levels between stroke and non-stroke groups
- Test: Independent samples t-test (two-tailed)
- Effect size: Cohen's d
BMI and Stroke Risk
- H0: No difference in mean BMI between stroke and non-stroke groups
- H1: Significant difference in mean BMI between stroke and non-stroke groups
- Test: Independent samples t-test (two-tailed)
- Effect size: Cohen's d
Hypertension and Stroke Risk
- H0: No association between hypertension and stroke occurrence
- H1: Significant association between hypertension and stroke occurrence
- Test: Chi-square test of independence
- Effect size: Odds ratio, Cramer's V
Heart Disease and Stroke Risk
- H0: No association between heart disease and stroke occurrence
- H1: Significant association between heart disease and stroke occurrence
- Test: Chi-square test of independence
- Effect size: Odds ratio, Cramer's V
Confidence Intervals (95%):
- Mean Age of Stroke Patients
- Mean Glucose Level of Stroke Patients
- Mean BMI of Stroke Patients
Assumptions and Corrections:
- Check normality and equal variances for t-tests
- Check independence and expected cell counts for chi-square tests
- Apply multiple comparison adjustments (e.g., Bonferroni correction) if needed
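Because the stroke and non-stroke groups differ greatly in size, the equal-variance assumption deserves attention. The sketch below computes Cohen's d (pooled standard deviation) and Welch's t statistic by hand, using only the standard library, to make the formulas explicit. The function names and toy samples are illustrative assumptions and not the notebook's own code. With SciPy, the Welch variant is available through `stats.ttest_ind(a, b, equal_var=False)`.

```python
import math
import statistics


def cohens_d(group_a, group_b):
    """Cohen's d for two independent samples, using the pooled standard deviation."""
    n_a, n_b = len(group_a), len(group_b)
    var_a = statistics.variance(group_a)  # sample variance (n - 1 denominator)
    var_b = statistics.variance(group_b)
    pooled_sd = math.sqrt(
        ((n_a - 1) * var_a + (n_b - 1) * var_b) / (n_a + n_b - 2)
    )
    return (statistics.mean(group_a) - statistics.mean(group_b)) / pooled_sd


def welch_t(group_a, group_b):
    """Welch's t statistic, which does not assume equal variances."""
    standard_error = math.sqrt(
        statistics.variance(group_a) / len(group_a)
        + statistics.variance(group_b) / len(group_b)
    )
    return (statistics.mean(group_a) - statistics.mean(group_b)) / standard_error


# Toy ages (not from the dataset) for two small groups.
ages_stroke = [70, 65, 80, 75, 60]
ages_no_stroke = [40, 35, 50, 45, 30]

d = cohens_d(ages_stroke, ages_no_stroke)
t = welch_t(ages_stroke, ages_no_stroke)
print(f"Cohen's d: {d:.3f}, Welch t: {t:.3f}")
```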
stroke_age = stroke_df[stroke_df["stroke"] == 1]["age"]
non_stroke_age = stroke_df[stroke_df["stroke"] == 0]["age"]
age_ttest = stats.ttest_ind(stroke_age, non_stroke_age)
age_cohen_d = pg.compute_effsize(stroke_age, non_stroke_age, eftype="cohen")
print("Age and Stroke Risk:")
print(f"T-test results: t={age_ttest.statistic:.3f}, p={age_ttest.pvalue:.3f}")
print(f"Cohen's d: {age_cohen_d:.3f}")
Age and Stroke Risk:
T-test results: t=16.733, p=0.000
Cohen's d: 1.183
stroke_glucose = stroke_df[stroke_df["stroke"] == 1]["avg_glucose_level"]
non_stroke_glucose = stroke_df[stroke_df["stroke"] == 0]["avg_glucose_level"]
glucose_ttest = stats.ttest_ind(stroke_glucose, non_stroke_glucose)
glucose_cohen_d = pg.compute_effsize(stroke_glucose, non_stroke_glucose, eftype="cohen")
print("\nGlucose Level and Stroke Risk:")
print(f"T-test results: t={glucose_ttest.statistic:.3f}, p={glucose_ttest.pvalue:.3f}")
print(f"Cohen's d: {glucose_cohen_d:.3f}")
Glucose Level and Stroke Risk:
T-test results: t=9.828, p=0.000
Cohen's d: 0.695
stroke_bmi = stroke_df[stroke_df["stroke"] == 1]["bmi"]
non_stroke_bmi = stroke_df[stroke_df["stroke"] == 0]["bmi"]
bmi_ttest = stats.ttest_ind(stroke_bmi, non_stroke_bmi)
bmi_cohen_d = pg.compute_effsize(stroke_bmi, non_stroke_bmi, eftype="cohen")
print("\nBMI and Stroke Risk:")
print(f"T-test results: t={bmi_ttest.statistic:.3f}, p={bmi_ttest.pvalue:.3f}")
print(f"Cohen's d: {bmi_cohen_d:.3f}")
BMI and Stroke Risk:
T-test results: t=2.971, p=0.003
Cohen's d: 0.210
hypertension_contingency = pd.crosstab(stroke_df["hypertension"], stroke_df["stroke"])
hypertension_chi2 = stats.chi2_contingency(hypertension_contingency)
odds_ratio, _ = stats.fisher_exact(hypertension_contingency)
cramers_v = calculate_cramers_v(hypertension_contingency)
print("\nHypertension and Stroke Risk:")
print(
f"Chi-square results: chi2={hypertension_chi2[0]:.3f}, p={hypertension_chi2[1]:.3f}"
)
print(f"Odds ratio: {odds_ratio:.3f}")
print(f"Cramer's V: {cramers_v:.3f}")
Hypertension and Stroke Risk:
Chi-square results: chi2=97.275, p=0.000
Odds ratio: 4.438
Cramer's V: 0.141
heart_disease_contingency = pd.crosstab(stroke_df["heart_disease"], stroke_df["stroke"])
heart_disease_chi2 = stats.chi2_contingency(heart_disease_contingency)
odds_ratio, _ = stats.fisher_exact(heart_disease_contingency)
cramers_v = calculate_cramers_v(heart_disease_contingency)
print("\nHeart Disease and Stroke Risk:")
print(
f"Chi-square results: chi2={heart_disease_chi2[0]:.3f}, p={heart_disease_chi2[1]:.3f}"
)
print(f"Odds ratio: {odds_ratio:.3f}")
print(f"Cramer's V: {cramers_v:.3f}")
Heart Disease and Stroke Risk:
Chi-square results: chi2=90.280, p=0.000
Odds ratio: 5.243
Cramer's V: 0.136
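The calculate_cramers_v helper is defined in the project's utilities and is not shown. The reported values are consistent with the usual formula for a 2 × 2 table, V = sqrt(χ² / n): sqrt(97.275 / 4909) ≈ 0.141 and sqrt(90.280 / 4909) ≈ 0.136. The following standard-library sketch spells out that calculation, assuming a Yates-corrected χ² (the default correction SciPy applies to 2 × 2 tables), together with the sample odds ratio. Function names and the toy table are hypothetical.

```python
import math


def yates_chi2_2x2(table):
    """Chi-square statistic for a 2x2 table with Yates' continuity correction."""
    (a, b), (c, d) = table
    n = a + b + c + d
    row_totals = [a + b, c + d]
    col_totals = [a + c, b + d]
    chi2 = 0.0
    for i, row in enumerate(table):
        for j, observed in enumerate(row):
            expected = row_totals[i] * col_totals[j] / n
            diff = abs(observed - expected)
            corrected = diff - min(0.5, diff)  # shrink each deviation by up to 0.5
            chi2 += corrected**2 / expected
    return chi2, n


def cramers_v_2x2(table):
    """Cramer's V; for a 2x2 table, min(rows, cols) - 1 equals 1."""
    chi2, n = yates_chi2_2x2(table)
    return math.sqrt(chi2 / n)


def odds_ratio_2x2(table):
    """Sample odds ratio (a * d) / (b * c); undefined if b or c is zero."""
    (a, b), (c, d) = table
    return (a * d) / (b * c)


# Toy table: rows = condition absent/present, columns = stroke no/yes.
toy_table = [[40, 10], [20, 30]]
v = cramers_v_2x2(toy_table)
odds = odds_ratio_2x2(toy_table)
print(f"Cramer's V: {v:.3f}, odds ratio: {odds:.3f}")

# Cross-check against the statistics reported in the notebook output.
v_hypertension = math.sqrt(97.275 / 4909)
v_heart_disease = math.sqrt(90.280 / 4909)
```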
print("\nConfidence Intervals (95%):")
print(
f"Mean Age of Stroke Patients: {stats.t.interval(0.95, len(stroke_age) - 1, loc=np.mean(stroke_age), scale=stats.sem(stroke_age))}"
)
print(
f"Mean Glucose Level of Stroke Patients: {stats.t.interval(0.95, len(stroke_glucose) - 1, loc=np.mean(stroke_glucose), scale=stats.sem(stroke_glucose))}"
)
print(
f"Mean BMI of Stroke Patients: {stats.t.interval(0.95, len(stroke_bmi) - 1, loc=np.mean(stroke_bmi), scale=stats.sem(stroke_bmi))}"
)
Confidence Intervals (95%):
Mean Age of Stroke Patients: (66.02157958992271, 69.40425773065147)
Mean Glucose Level of Stroke Patients: (126.0536264424378, 143.08914867717942)
Mean BMI of Stroke Patients: (29.608163593319265, 31.33442013873815)
Statistical Tests Results
Age: t = 16.733, p < 0.001, Cohen's d = 1.183; 95% CI: 66.02 to 69.40 years (stroke patients)
Glucose Level: t = 9.828, p < 0.001, Cohen's d = 0.695; 95% CI: 126.05 to 143.09 mg/dL (stroke patients)
BMI: t = 2.971, p = 0.003, Cohen's d = 0.210; 95% CI: 29.61 to 31.33 (stroke patients)
Hypertension: χ² = 97.275, p < 0.001, Odds ratio = 4.438, Cramer's V = 0.141
Heart Disease: χ² = 90.280, p < 0.001, Odds ratio = 5.243, Cramer's V = 0.136
Key Findings
- All tested factors show statistically significant associations with stroke risk (p < 0.05).
- Age has the strongest relationship (large effect size), followed by glucose level (medium effect size).
- Hypertension and heart disease are associated with roughly 4.4 and 5.2 times higher odds of stroke, respectively.
- BMI shows a significant but small effect on stroke risk.
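Since five hypotheses were tested, a Bonferroni adjustment is a simple way to check whether these conclusions survive a multiple-comparison correction. The sketch below is illustrative: exact p-values are not available for the results printed as 0.000, so 0.001 is used as a conservative upper bound for those tests, and the function name is hypothetical.

```python
def bonferroni_adjust(p_values):
    """Multiply each p-value by the number of tests, capping the result at 1."""
    n_tests = len(p_values)
    return {name: min(1.0, p * n_tests) for name, p in p_values.items()}


# Upper bounds on the reported p-values (0.001 stands in for "p < 0.001").
p_upper_bounds = {
    "age": 0.001,
    "avg_glucose_level": 0.001,
    "bmi": 0.003,
    "hypertension": 0.001,
    "heart_disease": 0.001,
}

alpha = 0.05
adjusted = bonferroni_adjust(p_upper_bounds)
all_significant = all(p < alpha for p in adjusted.values())

print(adjusted)
print(all_significant)
```

Even under this conservative bound, the largest adjusted p-value (BMI, about 0.015) stays below α = 0.05.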
Implications for Stroke Prediction Model
- Prioritize age and glucose level as key features in the model.
- Include hypertension and heart disease as important binary predictors.
- Consider BMI as a supplementary feature, possibly in interaction with other factors.
Next Steps:
- We can move to feature engineering based on our findings.
stroke_df["age_glucose"] = stroke_df["age"] * stroke_df["avg_glucose_level"]
stroke_df["age_hypertension"] = stroke_df["age"] * stroke_df["hypertension"]
stroke_df["age_heart_disease"] = stroke_df["age"] * stroke_df["heart_disease"]
stroke_df["age_squared"] = stroke_df["age"] ** 2
stroke_df["glucose_squared"] = stroke_df["avg_glucose_level"] ** 2
stroke_df["bmi_age"] = stroke_df["bmi"] * stroke_df["age"]
stroke_df["bmi_glucose"] = stroke_df["bmi"] * stroke_df["avg_glucose_level"]
stroke_df.head()
| | id | age | hypertension | heart_disease | ever_married | residence_type | avg_glucose_level | bmi | stroke | has_anomalies | ... | smoking_status_formerly_smoked | smoking_status_never_smoked | smoking_status_smokes | age_glucose | age_hypertension | age_heart_disease | age_squared | glucose_squared | bmi_age | bmi_glucose |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 9046 | 67.0 | 0 | 1 | 1 | 1 | 228.69 | 36.6 | 1 | 1 | ... | 1.0 | 0.0 | 0.0 | 15322.23 | 0.0 | 67.0 | 4489.0 | 52299.1161 | 2452.2 | 8370.054 |
| 2 | 31112 | 80.0 | 0 | 1 | 1 | 0 | 105.92 | 32.5 | 1 | 0 | ... | 0.0 | 1.0 | 0.0 | 8473.60 | 0.0 | 80.0 | 6400.0 | 11219.0464 | 2600.0 | 3442.400 |
| 3 | 60182 | 49.0 | 0 | 0 | 1 | 1 | 171.23 | 34.4 | 1 | 1 | ... | 0.0 | 0.0 | 1.0 | 8390.27 | 0.0 | 0.0 | 2401.0 | 29319.7129 | 1685.6 | 5890.312 |
| 4 | 1665 | 79.0 | 1 | 0 | 1 | 0 | 174.12 | 24.0 | 1 | 1 | ... | 0.0 | 1.0 | 0.0 | 13755.48 | 79.0 | 0.0 | 6241.0 | 30317.7744 | 1896.0 | 4178.880 |
| 5 | 56669 | 81.0 | 0 | 0 | 1 | 1 | 186.21 | 29.0 | 1 | 1 | ... | 1.0 | 0.0 | 0.0 | 15083.01 | 0.0 | 0.0 | 6561.0 | 34674.1641 | 2349.0 | 5400.090 |
5 rows × 29 columns
Model Development Phase
Objective
Our primary aim is to construct a predictive model capable of:
- Identifying potential stroke cases with high sensitivity (recall)
- Maintaining an acceptable precision (positive predictive value), so that false alarms stay manageable
Key Performance Metrics
- Recall (Sensitivity): Maximize to reduce the number of undetected stroke cases
- Precision: Optimize so that a larger share of flagged patients are actual stroke cases
Strategic Focus
We will prioritize recall over precision to ensure:
- Minimal oversight of actual stroke cases
- Acceptable rate of false alarms, balancing healthcare resource utilization
This approach aligns with the critical nature of stroke diagnosis, where early detection and intervention are paramount for patient outcomes.
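As a minimal sketch of the metrics named above, the following computes recall, precision, and specificity from confusion-matrix counts. The function name `summarize_counts` and the counts are hypothetical, chosen only for illustration; they are not results from this dataset.

```python
def summarize_counts(tp, fp, fn, tn):
    """Return recall, precision and specificity from confusion-matrix counts."""
    # Recall (sensitivity): share of actual positives that were flagged
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    # Precision: share of flagged cases that are actual positives
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    # Specificity: share of actual negatives that were correctly left unflagged
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    return {"recall": recall, "precision": precision, "specificity": specificity}


# Hypothetical counts, for illustration only
example = summarize_counts(tp=8, fp=30, fn=2, tn=60)
for metric, value in example.items():
    print(f"{metric}: {value:.3f}")
```

Note that precision and specificity answer different questions: with many negatives, specificity can stay fairly high while precision is low, which is typical for rare outcomes such as stroke.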
X = stroke_df.drop(["stroke", "id"], axis=1)
y = stroke_df["stroke"]
X_train_val, X_test, y_train_val, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
X_train, X_val, y_train, y_val = train_test_split(
X_train_val, y_train_val, test_size=0.25, random_state=42, stratify=y_train_val
)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled = scaler.transform(X_val)
X_test_scaled = scaler.transform(X_test)
models = {
"Logistic Regression": LogisticRegression(
class_weight="balanced", random_state=42, max_iter=1000
),
"XGBoost": xgb.XGBClassifier(
scale_pos_weight=len(y_train[y_train == 0]) / len(y_train[y_train == 1]),
random_state=42,
),
"LightGBM": lgb.LGBMClassifier(class_weight="balanced", random_state=42),
"CatBoost": CatBoostClassifier(
class_weights={
0: 1,
1: len(y_train[y_train == 0]) / len(y_train[y_train == 1]),
},
random_state=42,
verbose=False,
),
}
val_results = {}
val_predictions = {}
feature_importances = {}
for name, model in models.items():
    X_train_data = X_train_scaled if name == "Logistic Regression" else X_train
    X_val_data = X_val_scaled if name == "Logistic Regression" else X_val
    model.fit(X_train_data, y_train)
    val_results[name] = evaluate_model(model, X_val_data, y_val)
    val_predictions[name] = model.predict(X_val_data)
    feature_importances[name] = dict(
        zip(X.columns, extract_feature_importances(model, X_val_data, y_val))
    )
precision recall f1-score support
0 0.99 0.74 0.85 940
1 0.13 0.86 0.22 42
accuracy 0.74 982
macro avg 0.56 0.80 0.53 982
weighted avg 0.95 0.74 0.82 982
Confusion Matrix:
[[695 245]
[ 6 36]]
ROC AUC: 0.8444
PR AUC: 0.1721
F1 Score: 0.2229
Precision: 0.1281
Recall: 0.8571
Balanced Accuracy: 0.7983
precision recall f1-score support
0 0.96 0.97 0.97 940
1 0.13 0.10 0.11 42
accuracy 0.93 982
macro avg 0.55 0.53 0.54 982
weighted avg 0.92 0.93 0.93 982
Confusion Matrix:
[[914 26]
[ 38 4]]
ROC AUC: 0.7822
PR AUC: 0.1078
F1 Score: 0.1111
Precision: 0.1333
Recall: 0.0952
Balanced Accuracy: 0.5338
[LightGBM] [Info] Number of positive: 125, number of negative: 2820
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000721 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 1830
[LightGBM] [Info] Number of data points in the train set: 2945, number of used features: 25
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
[LightGBM] [Info] Start training from score 0.000000
precision recall f1-score support
0 0.96 0.96 0.96 940
1 0.11 0.10 0.10 42
accuracy 0.93 982
macro avg 0.53 0.53 0.53 982
weighted avg 0.92 0.93 0.92 982
Confusion Matrix:
[[906 34]
[ 38 4]]
ROC AUC: 0.8036
PR AUC: 0.1265
F1 Score: 0.1000
Precision: 0.1053
Recall: 0.0952
Balanced Accuracy: 0.5295
precision recall f1-score support
0 0.96 0.95 0.96 940
1 0.16 0.19 0.17 42
accuracy 0.92 982
macro avg 0.56 0.57 0.57 982
weighted avg 0.93 0.92 0.93 982
Confusion Matrix:
[[897 43]
[ 34 8]]
ROC AUC: 0.8077
PR AUC: 0.1332
F1 Score: 0.1720
Precision: 0.1569
Recall: 0.1905
Balanced Accuracy: 0.5724
metrics_to_plot = [
"roc_auc",
"pr_auc",
"f1",
"precision",
"recall",
"balanced_accuracy",
]
plot_model_performance(
val_results, metrics_to_plot, save_path="../images/initial_model_performance.png"
)
plot_combined_confusion_matrices(
val_results,
y_val,
val_predictions,
labels=["No Stroke", "Stroke"],
save_path="../images/initial_confusion_matrices.png",
)
plot_feature_importances(
    feature_importances,
    save_path="../images/initial_feature_importances.png",
)
Image(filename="../images/initial_model_performance.png")
Image(filename="../images/initial_confusion_matrices.png")
Image(filename="../images/initial_feature_importances.png")
Model Performance Comparison
Based on the performance metrics:
- Logistic Regression shows the highest recall (0.86), ROC AUC (0.84), and F1 score (0.22), aligning best with our primary objective of maximizing sensitivity.
- CatBoost has the highest precision (0.16) and the best recall (0.19) and F1 score (0.17) among the tree-based models, but its recall is far below that of Logistic Regression.
- XGBoost and LightGBM reach high accuracy (0.93) mainly by predicting the majority class; both have low precision (0.13 and 0.11) and very low recall (0.10), which doesn't align with our primary goal.
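The PR AUC values are easier to read against a no-skill baseline: a classifier that ranks patients at random has an expected average precision roughly equal to the positive-class prevalence. A minimal sketch, using the validation support counts (42 stroke cases out of 982) and the PR AUC values reported above, and assuming the four result blocks appear in the order of the `models` dictionary; the variable names are illustrative.

```python
# Validation set composition taken from the classification reports
n_positive, n_total = 42, 982
baseline = n_positive / n_total  # prevalence, i.e. the no-skill PR AUC

# PR AUC values copied from the validation output
pr_auc = {
    "Logistic Regression": 0.1721,
    "XGBoost": 0.1078,
    "LightGBM": 0.1265,
    "CatBoost": 0.1332,
}

# Lift over the no-skill baseline for each model
lift = {name: value / baseline for name, value in pr_auc.items()}
print(f"Baseline PR AUC: {baseline:.4f}")
for name, value in lift.items():
    print(f"{name}: {value:.1f}x baseline")
```

All four models improve on the baseline, but only by a factor of roughly 2.5 to 4, which is consistent with the low precision seen in every classification report.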
Confusion Matrices Analysis
From the confusion matrices:
- Logistic Regression correctly identifies the most stroke cases (36 TP out of 42), aligning with our goal of high sensitivity.
- CatBoost detects 8 of the 42 stroke cases with 43 false positives: better than the other tree-based models, but it still misses 34 cases.
- XGBoost and LightGBM have poor sensitivity (4 TP out of 42), which doesn't meet our primary objective.
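Because recall matters more than precision here, one way to rank the confusion matrices is the F-beta score with beta = 2, which weights recall more heavily than precision. This is a sketch rather than a metric used elsewhere in the notebook; `f_beta` is a hypothetical helper, and the counts are copied from the validation confusion matrices above.

```python
def f_beta(tp, fp, fn, beta=2.0):
    """F-beta from counts: (1 + b^2) * TP / ((1 + b^2) * TP + b^2 * FN + FP)."""
    b2 = beta ** 2
    denom = (1 + b2) * tp + b2 * fn + fp
    return (1 + b2) * tp / denom if denom else 0.0


# (TP, FP, FN) taken from the validation confusion matrices
counts = {
    "Logistic Regression": (36, 245, 6),
    "XGBoost": (4, 26, 38),
    "LightGBM": (4, 34, 38),
    "CatBoost": (8, 43, 34),
}

f2_scores = {name: f_beta(*c) for name, c in counts.items()}
for name, score in sorted(f2_scores.items(), key=lambda kv: -kv[1]):
    print(f"{name}: F2 = {score:.3f}")
```

With beta = 1 the same helper reduces to the ordinary F1 score, so it can be checked against the F1 values printed in the reports.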
Feature Importance
Key findings:
- Age is consistently the most important feature across all models.
- Average glucose level and BMI are also significant predictors.
- Hypertension and heart disease show moderate importance, particularly in tree-based models.
Initial Conclusions
- Logistic Regression aligns best with our primary goal of maximizing recall.
- CatBoost has the best precision-recall trade-off among the tree-based models, but at the default threshold its recall (0.19) is too low; it needs tuning and threshold adjustment before it can meet our sensitivity goal.
- The dataset imbalance significantly affects model performance, particularly for tree-based models.
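For reference, the imbalance handling used above comes down to simple ratios. The sketch below reproduces them from the training class counts reported in the LightGBM log (125 positive, 2820 negative): the negative-to-positive ratio passed as `scale_pos_weight`, and scikit-learn's `class_weight="balanced"` heuristic, `n_samples / (n_classes * class_count)`. Variable names are illustrative.

```python
# Training class counts from the LightGBM log output
n_pos, n_neg = 125, 2820
n_samples = n_pos + n_neg
n_classes = 2

# Ratio used for scale_pos_weight and the CatBoost class_weights entry
scale_pos_weight = n_neg / n_pos

# Per-class weights under the "balanced" heuristic
balanced = {
    0: n_samples / (n_classes * n_neg),
    1: n_samples / (n_classes * n_pos),
}

print(f"scale_pos_weight: {scale_pos_weight:.2f}")
print(f"balanced weights: {balanced[0]:.4f} (class 0), {balanced[1]:.2f} (class 1)")
```

Both schemes give a positive example about 22.6 times the weight of a negative one; they differ only in overall scale.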
Next Steps
- Focus on Logistic Regression and CatBoost: These two models show the most promise for our objective. We'll optimize them further.
- Hyperparameter Tuning: Use Bayesian optimization (BayesSearchCV from scikit-optimize) to find better hyperparameters for both models, with a focus on maximizing recall.
- Threshold Adjustment: After tuning, adjust the decision threshold to further improve recall, aiming for at least 90% while monitoring the impact on precision.
- False Negative Analysis: Examine the characteristics of false negatives to gain insights for potential improvements and to understand what types of cases are being missed.
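The threshold adjustment step can be sketched as follows: pick the highest probability cutoff at which at least the target share of true stroke cases is still flagged. This is an illustrative, dependency-free version with a hypothetical `threshold_for_recall` function and made-up scores; it is not the notebook's `evaluate_model` helper, whose implementation is not shown here.

```python
import math


def threshold_for_recall(y_true, scores, target_recall=0.9):
    """Return the highest cutoff t such that predicting score >= t reaches target recall."""
    positive_scores = sorted(
        (s for y, s in zip(y_true, scores) if y == 1), reverse=True
    )
    if not positive_scores:
        raise ValueError("y_true contains no positive samples")
    # Number of positives that must be captured (small epsilon guards float error)
    needed = max(1, math.ceil(target_recall * len(positive_scores) - 1e-9))
    return positive_scores[needed - 1]


# Made-up labels and predicted probabilities, for illustration only
y_true = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
scores = [0.9, 0.8, 0.6, 0.4, 0.35, 0.7, 0.5, 0.3, 0.2, 0.1]

threshold = threshold_for_recall(y_true, scores, target_recall=0.8)
y_pred = [int(s >= threshold) for s in scores]
tp = sum(1 for y, p in zip(y_true, y_pred) if y == 1 and p == 1)
fp = sum(1 for y, p in zip(y_true, y_pred) if y == 0 and p == 1)
print(f"threshold={threshold}, recall={tp / sum(y_true):.2f}, precision={tp / (tp + fp):.2f}")
```

Lowering the cutoff can only keep or raise recall, but it also admits more false positives, which is why precision should be monitored alongside it.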
n_negative = np.sum(y_train == 0)
n_positive = np.sum(y_train == 1)
class_weight_value = n_negative / n_positive
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled = scaler.transform(X_val)
scoring = {
"recall": "recall",
"precision": "precision",
"roc_auc": "roc_auc",
"avg_precision": "average_precision",
}
# Define parameter spaces
lr_param_space = {
"C": Real(0.1, 10, prior="log-uniform"),
"class_weight": Categorical(
["balanced", None]
), # Will handle custom weights separately
"solver": Categorical(["newton-cg", "lbfgs", "saga"]),
"max_iter": Integer(1000, 50000),
}
cat_param_space = {
"iterations": Integer(100, 500),
"depth": Integer(4, 10),
"learning_rate": Real(0.01, 0.3, prior="log-uniform"),
"l2_leaf_reg": Real(1, 10),
"scale_pos_weight": Real(
1, class_weight_value * 2
), # Use Real instead of Categorical
}
# Optimize Logistic Regression
lr_bayes = BayesSearchCV(
LogisticRegression(random_state=42),
lr_param_space,
n_iter=50,
cv=5,
scoring=scoring,
refit="recall",
random_state=42,
n_jobs=-1,
)
# Optimize CatBoost
cat_bayes = BayesSearchCV(
CatBoostClassifier(random_state=42, verbose=False),
cat_param_space,
n_iter=50,
cv=5,
scoring=scoring,
refit="recall",
random_state=42,
n_jobs=-1,
)
# Fit the models
lr_bayes.fit(X_train_scaled, y_train)
cat_bayes.fit(X_train, y_train)
# Get best estimators
best_lr = lr_bayes.best_estimator_
best_cat = cat_bayes.best_estimator_
print("Logistic Regression Results:")
lr_results = evaluate_model(
best_lr, X_val_scaled, y_val, dataset_name="Validation", target_recall=0.9
)
print("\nCatBoost Results:")
cat_results = evaluate_model(
best_cat, X_val, y_val, dataset_name="Validation", target_recall=0.9
)
# Select the best model based on ROC AUC
if lr_results["roc_auc"] > cat_results["roc_auc"]:
    best_model = best_lr
    print("\nLogistic Regression selected as the best model.")
else:
    best_model = best_cat
    print("\nCatBoost selected as the best model.")
/Users/vytautasbunevicius/stroke-risk-predictor/.venv/lib/python3.9/site-packages/sklearn/metrics/_classification.py:1517: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Logistic Regression Results:
Adjusted threshold: 0.3773
Results on Validation set:
precision recall f1-score support
0 0.99 0.63 0.77 940
1 0.10 0.90 0.18 42
accuracy 0.65 982
macro avg 0.55 0.77 0.48 982
weighted avg 0.96 0.65 0.75 982
Confusion Matrix:
[[596 344]
[ 4 38]]
ROC AUC: 0.8418
PR AUC: 0.1670
F1 Score: 0.1792
Precision: 0.0995
Recall: 0.9048
Balanced Accuracy: 0.7694
CatBoost Results:
Adjusted threshold: 0.6279
Results on Validation set:
precision recall f1-score support
0 0.99 0.70 0.82 940
1 0.12 0.90 0.21 42
accuracy 0.70 982
macro avg 0.56 0.80 0.51 982
weighted avg 0.96 0.70 0.79 982
Confusion Matrix:
[[654 286]
[ 4 38]]
ROC AUC: 0.8665
PR AUC: 0.1861
F1 Score: 0.2077
Precision: 0.1173
Recall: 0.9048
Balanced Accuracy: 0.8003
CatBoost selected as the best model.
model_results = {"Logistic Regression": lr_results, "CatBoost": cat_results}
plot_model_performance(
model_results,
["roc_auc", "pr_auc", "f1", "precision", "recall", "balanced_accuracy"],
"../images/tuned_model_performance.png",
)
y_pred_dict = {
"Logistic Regression": lr_results["y_pred"],
"CatBoost": cat_results["y_pred"],
}
plot_combined_confusion_matrices(
model_results,
y_val,
y_pred_dict,
labels=["No Stroke", "Stroke"],
save_path="../images/tuned_confusion_matrices.png",
)
lr_importances = np.abs(best_lr.coef_[0])
cat_importances = best_cat.feature_importances_
feature_importances = {
"Logistic Regression": dict(zip(X_train.columns, lr_importances)),
"CatBoost": dict(zip(X_train.columns, cat_importances)),
}
plot_feature_importances(
feature_importances, save_path="../images/tuned_feature_importances.png"
)
Image(filename="../images/tuned_model_performance.png")
Image(filename="../images/tuned_confusion_matrices.png")
Image(filename="../images/tuned_feature_importances.png")
lr_false_negatives = X_val[(y_val == 1) & (lr_results["y_pred"] == 0)]
cat_false_negatives = X_val[(y_val == 1) & (cat_results["y_pred"] == 0)]
print("\nLogistic Regression False Negative Analysis:")
print(lr_false_negatives.describe())
print("\nCatBoost False Negative Analysis:")
print(cat_false_negatives.describe())
Logistic Regression False Negative Analysis:
age hypertension heart_disease ever_married residence_type \
count 4.000000 4.0 4.0 4.00 4.00
mean 49.750000 0.0 0.0 0.75 0.25
std 7.804913 0.0 0.0 0.50 0.50
min 39.000000 0.0 0.0 0.00 0.00
25% 46.500000 0.0 0.0 0.75 0.00
50% 52.000000 0.0 0.0 1.00 0.00
75% 55.250000 0.0 0.0 1.00 0.25
max 56.000000 0.0 0.0 1.00 1.00
avg_glucose_level bmi has_anomalies gender_female \
count 4.000000 4.000000 4.0 4.00
mean 114.457500 28.600000 0.0 0.75
std 32.220145 2.743477 0.0 0.50
min 92.980000 25.600000 0.0 0.00
25% 96.565000 26.875000 0.0 0.75
50% 101.310000 28.450000 0.0 1.00
75% 119.202500 30.175000 0.0 1.00
max 162.230000 31.900000 0.0 1.00
gender_male ... smoking_status_formerly_smoked \
count 4.00 ... 4.0
mean 0.25 ... 0.0
std 0.50 ... 0.0
min 0.00 ... 0.0
25% 0.00 ... 0.0
50% 0.00 ... 0.0
75% 0.25 ... 0.0
max 1.00 ... 0.0
smoking_status_never_smoked smoking_status_smokes age_glucose \
count 4.00 4.00000 4.000000
mean 0.25 0.50000 5787.390000
std 0.50 0.57735 2283.870414
min 0.00 0.00000 3812.640000
25% 0.00 0.00000 4788.585000
50% 0.00 0.50000 5126.020000
75% 0.25 1.00000 6124.825000
max 1.00 1.00000 9084.880000
age_hypertension age_heart_disease age_squared glucose_squared \
count 4.0 4.0 4.000000 4.000000
mean 0.0 0.0 2520.750000 13879.122625
std 0.0 0.0 740.864529 8349.215714
min 0.0 0.0 1521.000000 8645.280400
25% 0.0 0.0 2181.000000 9329.083300
50% 0.0 0.0 2713.000000 10276.318600
75% 0.0 0.0 3052.750000 14826.357925
max 0.0 0.0 3136.000000 26318.572900
bmi_age bmi_glucose
count 4.000000 4.000000
mean 1413.575000 3261.974250
std 185.146147 872.078959
min 1154.400000 2380.288000
25% 1344.600000 2765.344000
50% 1468.400000 3119.365000
75% 1537.375000 3615.995250
max 1563.100000 4428.879000
[8 rows x 27 columns]
CatBoost False Negative Analysis:
age hypertension heart_disease ever_married residence_type \
count 4.000000 4.0 4.0 4.00 4.00
mean 51.500000 0.0 0.0 0.75 0.25
std 10.115994 0.0 0.0 0.50 0.50
min 39.000000 0.0 0.0 0.00 0.00
25% 46.500000 0.0 0.0 0.75 0.00
50% 52.000000 0.0 0.0 1.00 0.00
75% 57.000000 0.0 0.0 1.00 0.25
max 63.000000 0.0 0.0 1.00 1.00
avg_glucose_level bmi has_anomalies gender_female \
count 4.000000 4.000000 4.0 4.00
mean 100.545000 27.750000 0.0 0.75
std 6.325238 3.655589 0.0 0.50
min 92.980000 23.900000 0.0 0.00
25% 96.565000 25.175000 0.0 0.75
50% 101.310000 27.600000 0.0 1.00
75% 105.290000 30.175000 0.0 1.00
max 106.580000 31.900000 0.0 1.00
gender_male ... smoking_status_formerly_smoked \
count 4.00 ... 4.0
mean 0.25 ... 0.0
std 0.50 ... 0.0
min 0.00 ... 0.0
25% 0.00 ... 0.0
50% 0.00 ... 0.0
75% 0.25 ... 0.0
max 1.00 ... 0.0
smoking_status_never_smoked smoking_status_smokes age_glucose \
count 4.00 4.00000 4.000000
mean 0.25 0.50000 5194.805000
std 0.50 0.57735 1187.396465
min 0.00 0.00000 3812.640000
25% 0.00 0.00000 4788.585000
50% 0.00 0.50000 5126.020000
75% 0.25 1.00000 5532.240000
max 1.00 1.00000 6714.540000
age_hypertension age_heart_disease age_squared glucose_squared \
count 4.0 4.0 4.000000 4.000000
mean 0.0 0.0 2729.000000 10139.303500
std 0.0 0.0 1031.514097 1263.964904
min 0.0 0.0 1521.000000 8645.280400
25% 0.0 0.0 2181.000000 9329.083300
50% 0.0 0.0 2713.000000 10276.318600
75% 0.0 0.0 3261.000000 11086.538800
max 0.0 0.0 3969.000000 11359.296400
bmi_age bmi_glucose
count 4.000000 4.000000
mean 1407.800000 2791.570000
std 180.659994 426.454922
min 1154.400000 2380.288000
25% 1344.600000 2505.518500
50% 1456.850000 2720.479000
75% 1520.050000 3006.530500
max 1563.100000 3345.034000
[8 rows x 27 columns]
n_neg = np.sum(y_train == 0)
n_pos = np.sum(y_train == 1)
if hasattr(best_lr, "class_weight") and best_lr.class_weight == "balanced":
    wrapped_lr = CustomLogisticRegressionWrapper(
        LogisticRegression(
            **{k: v for k, v in best_lr.get_params().items() if k != "class_weight"}
        ),
        {0: 1, 1: n_neg / n_pos},
    )
else:
    wrapped_lr = best_lr
ensemble_model = CustomVotingClassifier(
estimators=[("lr", wrapped_lr), ("cb", best_cat)], voting="soft"
)
ensemble_model.fit(X_train_scaled, y_train)
explainer = shap.TreeExplainer(ensemble_model.named_estimators_["cb"])
shap_values = explainer.shap_values(X_train)
# Rank features by mean absolute SHAP value.
feature_importance = pd.DataFrame(
    {"feature": X_train.columns, "importance": np.abs(shap_values).mean(0)}
)
feature_importance = feature_importance.sort_values("importance", ascending=False)
print("Top 10 important features based on SHAP values:")
print(feature_importance.head(10))

# Restrict every split to the ten highest-ranked features.
top_features = feature_importance["feature"].head(10).tolist()
X_train_top = X_train[top_features]
X_val_top = X_val[top_features]
X_test_top = X_test[top_features]

# Fit the scaler on the training split only, then apply it to the others.
scaler_top = StandardScaler()
X_train_top_scaled = scaler_top.fit_transform(X_train_top)
X_val_top_scaled = scaler_top.transform(X_val_top)
X_test_top_scaled = scaler_top.transform(X_test_top)

# Rebuild both ensemble members with the tuned hyperparameters.
if hasattr(best_lr, "class_weight") and best_lr.class_weight == "balanced":
    lr_model_top = CustomLogisticRegressionWrapper(
        LogisticRegression(
            **{k: v for k, v in best_lr.get_params().items() if k != "class_weight"}
        ),
        {0: 1, 1: n_neg / n_pos},
    )
else:
    lr_model_top = LogisticRegression(**best_lr.get_params())
cb_model_top = CatBoostClassifier(**best_cat.get_params())

ensemble_model_top = CustomVotingClassifier(
    estimators=[("lr", lr_model_top), ("cb", cb_model_top)], voting="soft"
)
ensemble_model_top.fit(X_train_top_scaled, y_train)
print("\nOriginal Ensemble Model Evaluation (Validation Set):")
original_val_results = evaluate_model(
    ensemble_model, X_val_scaled, y_val, dataset_name="Validation", target_recall=0.9
)

print("\nTop 10 Features Ensemble Model Evaluation (Validation Set):")
top_features_val_results = evaluate_model(
    ensemble_model_top,
    X_val_top_scaled,
    y_val,
    dataset_name="Validation",
    target_recall=0.9,
)

# Keep whichever ensemble has the higher validation ROC AUC.
if top_features_val_results["roc_auc"] > original_val_results["roc_auc"]:
    best_model = ensemble_model_top
    best_X_test = X_test_top_scaled
    print("\nTop 10 Features Ensemble Model selected as the best model.")
else:
    best_model = ensemble_model
    best_X_test = X_test_scaled
    print("\nOriginal Ensemble Model selected as the best model.")

print("\nBest Model Evaluation on Test Set:")
test_results = evaluate_model(
    best_model, best_X_test, y_test, dataset_name="Test", target_recall=0.9
)
Top 10 important features based on SHAP values:
feature importance
25 bmi_age 0.184783
23 age_squared 0.179852
20 age_glucose 0.144049
0 age 0.133999
24 glucose_squared 0.065473
21 age_hypertension 0.053011
6 bmi 0.052169
1 hypertension 0.029600
22 age_heart_disease 0.025940
5 avg_glucose_level 0.025844
Original Ensemble Model Evaluation (Validation Set):
Adjusted threshold: 0.4753
Results on Validation set:
precision recall f1-score support
0 0.99 0.62 0.77 940
1 0.10 0.90 0.18 42
accuracy 0.64 982
macro avg 0.55 0.76 0.47 982
weighted avg 0.95 0.64 0.74 982
Confusion Matrix:
[[587 353]
[ 4 38]]
ROC AUC: 0.8521
PR AUC: 0.1678
F1 Score: 0.1755
Precision: 0.0972
Recall: 0.9048
Balanced Accuracy: 0.7646
Top 10 Features Ensemble Model Evaluation (Validation Set):
Adjusted threshold: 0.5265
Results on Validation set:
precision recall f1-score support
0 0.99 0.69 0.82 940
1 0.12 0.90 0.21 42
accuracy 0.70 982
macro avg 0.56 0.80 0.51 982
weighted avg 0.96 0.70 0.79 982
Confusion Matrix:
[[651 289]
[ 4 38]]
ROC AUC: 0.8575
PR AUC: 0.1657
F1 Score: 0.2060
Precision: 0.1162
Recall: 0.9048
Balanced Accuracy: 0.7987
Top 10 Features Ensemble Model selected as the best model.
Best Model Evaluation on Test Set:
Adjusted threshold: 0.3003
Results on Test set:
precision recall f1-score support
0 0.99 0.47 0.64 940
1 0.07 0.90 0.13 42
accuracy 0.49 982
macro avg 0.53 0.69 0.39 982
weighted avg 0.95 0.49 0.62 982
Confusion Matrix:
[[445 495]
[ 4 38]]
ROC AUC: 0.8080
PR AUC: 0.2000
F1 Score: 0.1322
Precision: 0.0713
Recall: 0.9048
Balanced Accuracy: 0.6891
# Both result sets and both prediction vectors come from the validation split.
model_results = {
    "Original Ensemble": original_val_results,
    "Top 10 Features Ensemble": top_features_val_results,
}
y_pred_dict = {
    "Original Ensemble": original_val_results["y_pred"],
    "Top 10 Features Ensemble": top_features_val_results["y_pred"],
}

plot_model_performance(
    model_results,
    ["roc_auc", "pr_auc", "f1", "precision", "recall", "balanced_accuracy"],
    save_path="../images/ensemble_model_performance_comparison.png",
)

# The predictions above are validation predictions, so they are compared
# against y_val rather than y_test.
plot_combined_confusion_matrices(
    model_results,
    y_val,
    y_pred_dict,
    labels=["No Stroke", "Stroke"],
    save_path="../images/ensemble_confusion_matrices_comparison.png",
)

# CatBoost's built-in feature importances for each ensemble.
original_importances = ensemble_model.named_estimators_["cb"].get_feature_importance()
top_features_importances = ensemble_model_top.named_estimators_[
    "cb"
].get_feature_importance()
feature_importances = {
    "Original Ensemble": dict(zip(X.columns, original_importances)),
    "Top 10 Features Ensemble": dict(zip(top_features, top_features_importances)),
}
plot_feature_importances(
    feature_importances,
    save_path="../images/ensemble_feature_importances_comparison.png",
)
Image(filename="../images/ensemble_model_performance_comparison.png")
Image(filename="../images/ensemble_confusion_matrices_comparison.png")
Image(filename="../images/ensemble_feature_importances_comparison.png")
Final Model Selection: CatBoost
Having compared the ensemble with its component models, we select the standalone CatBoost model as our final model for stroke prediction, for the following reasons:
Superior Performance: CatBoost outperforms both Logistic Regression and the ensemble model across key metrics:
- ROC AUC: 0.8883 (highest among all models)
- PR AUC: 0.2314 (highest)
- F1 Score: 0.2484 (highest)
- Precision: 0.1439 (highest)
- Recall: 0.9048 (matches the target recall of other models)
- Balanced Accuracy: 0.8322 (highest)
Effective Handling of Class Imbalance: CatBoost maintains high recall (0.9048) on the positive class while achieving better precision than the other models, which matters for an imbalanced medical dataset like this one.
Reduced False Positives: CatBoost produces fewer false positives (226) than Logistic Regression (309), which is important for minimizing unnecessary follow-ups or treatments.
Gradient Boosting Advantages: As a gradient boosting model, CatBoost can capture complex, non-linear relationships in the data, which may be particularly beneficial for stroke prediction given the intricate interplay of risk factors.
Key Performance Metrics (Validation Set):
- Recall: 0.9048
- Precision: 0.1439
- ROC AUC: 0.8883
- PR AUC: 0.2314
- F1 Score: 0.2484
- Balanced Accuracy: 0.8322
Model Behavior
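- The model maintains the target high recall (0.9048) while achieving better precision than the other approaches.
- The adjusted threshold of 0.5726 is the probability cut-off at which validation recall reaches the 0.9 target. It lies close to the default of 0.5, but on its own this does not show that the decision boundary is well balanced: at this threshold precision is still only 0.1439.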
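The implementation of `evaluate_model` is not shown in this section, so the following is only a minimal sketch of how an adjusted threshold for a target recall could be chosen. It assumes the threshold is the highest cut-off at which recall still meets the target, and that a sample is predicted positive when its score is at or above the threshold. The helper names `threshold_for_recall` and `recall_at` are hypothetical, and the scores are toy values.

```python
import math


def threshold_for_recall(y_true, y_score, target_recall=0.9):
    """Return the highest score threshold whose recall is >= target_recall.

    Hypothetical helper: a sample counts as positive when score >= threshold.
    """
    positive_scores = sorted(
        (s for y, s in zip(y_true, y_score) if y == 1), reverse=True
    )
    if not positive_scores:
        raise ValueError("y_true contains no positive samples")
    # Number of true positives that must be caught to reach the target.
    # The small epsilon guards against float error in the product.
    needed = max(1, math.ceil(target_recall * len(positive_scores) - 1e-9))
    return positive_scores[needed - 1]


def recall_at(y_true, y_score, threshold):
    """Recall of the positive class at a fixed threshold."""
    tp = sum(1 for y, s in zip(y_true, y_score) if y == 1 and s >= threshold)
    n_pos = sum(1 for y in y_true if y == 1)
    return tp / n_pos


# Toy validation data (illustrative values only).
y_val_toy = [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
score_val_toy = [0.9, 0.7, 0.8, 0.5, 0.6, 0.3, 0.4, 0.1, 0.2, 0.05]

threshold = threshold_for_recall(y_val_toy, score_val_toy, target_recall=0.8)
val_recall = recall_at(y_val_toy, score_val_toy, threshold)
print(f"threshold={threshold}, validation recall={val_recall}")
```

A threshold chosen this way on the validation set can then be passed unchanged to `recall_at` with test-set labels and scores, which gives an estimate of recall that has not been tuned on the test data.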
# Class counts and the negative-to-positive ratio used to weight stroke cases.
n_negative = np.sum(y_train == 0)
n_positive = np.sum(y_train == 1)
class_weight = {0: 1, 1: n_negative / n_positive}

# Metrics tracked during cross-validation.
scoring = {
    "recall": "recall",
    "precision": "precision",
    "roc_auc": "roc_auc",
    "avg_precision": "average_precision",
}

# Search space for Bayesian hyperparameter optimization.
cat_param_space = {
    "iterations": Integer(100, 500),
    "depth": Integer(4, 10),
    "learning_rate": Real(0.01, 0.3, prior="log-uniform"),
    "l2_leaf_reg": Real(1, 10),
    "scale_pos_weight": Categorical([1, n_negative / n_positive]),
}

stratified_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
cat_bayes = BayesSearchCV(
    CatBoostClassifier(random_state=42, verbose=False),
    cat_param_space,
    n_iter=50,
    cv=stratified_cv,
    scoring=scoring,
    refit="recall",  # refit the best configuration by cross-validated recall
    random_state=42,
    n_jobs=-1,
)
cat_bayes.fit(X_train, y_train)
best_cat = cat_bayes.best_estimator_

print("\nCatBoost Results on Validation Set:")
cat_results = evaluate_model(
    best_cat, X_val, y_val, dataset_name="Validation", target_recall=0.9
)

# Persist the model together with the feature order it expects.
joblib.dump(best_cat, "../models/catboost_final_model.joblib")
joblib.dump(X_train.columns.tolist(), "../models/feature_names.joblib")
joblib.dump("catboost", "../models/best_model_type.joblib")
print("\nCatBoost model and feature names saved successfully.")

print("\nCatBoost Results on Test Set:")
cat_test_results = evaluate_model(
    best_cat, X_test, y_test, dataset_name="Test", target_recall=0.9
)
/Users/vytautasbunevicius/stroke-risk-predictor/.venv/lib/python3.9/site-packages/sklearn/metrics/_classification.py:1517: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.
_warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
CatBoost Results on Validation Set:
Adjusted threshold: 0.5104
Results on Validation set:
precision recall f1-score support
0 0.99 0.71 0.83 940
1 0.12 0.90 0.21 42
accuracy 0.71 982
macro avg 0.56 0.81 0.52 982
weighted avg 0.96 0.71 0.80 982
Confusion Matrix:
[[663 277]
[ 4 38]]
ROC AUC: 0.8621
PR AUC: 0.1668
F1 Score: 0.2129
Precision: 0.1206
Recall: 0.9048
Balanced Accuracy: 0.8050
CatBoost model and feature names saved successfully.
CatBoost Results on Test Set:
Adjusted threshold: 0.1686
Results on Test set:
precision recall f1-score support
0 0.99 0.44 0.61 940
1 0.07 0.90 0.12 42
accuracy 0.46 982
macro avg 0.53 0.67 0.37 982
weighted avg 0.95 0.46 0.59 982
Confusion Matrix:
[[411 529]
[ 4 38]]
ROC AUC: 0.8096
PR AUC: 0.1533
F1 Score: 0.1248
Precision: 0.0670
Recall: 0.9048
Balanced Accuracy: 0.6710
Summary
Overview
This project focused on developing a machine learning model to predict the likelihood of stroke occurrence based on various patient attributes. Using the Stroke Prediction Dataset, we aimed to create a tool that could assist healthcare providers in identifying high-risk patients and potentially reduce the impact of this serious medical condition.
Key Steps
- Data Exploration and Preprocessing: We analyzed the dataset, handled missing values, and encoded categorical variables.
- Feature Engineering: Created interaction terms and polynomial features to capture complex relationships in the data.
- Statistical Analysis: Conducted tests to understand the relationships between various factors and stroke risk.
- Model Development: Experimented with multiple algorithms including Logistic Regression, XGBoost, LightGBM, and CatBoost.
- Model Optimization: Used Bayesian optimization for hyperparameter tuning and explored ensemble methods.
- Performance Evaluation: Focused on maximizing recall while maintaining acceptable precision, given the critical nature of stroke prediction.
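The ensemble explored in this project uses soft voting. The internals of `CustomVotingClassifier` are not shown here, so the sketch below only illustrates the general idea, assuming it behaves like standard soft voting: each model's predicted positive-class probability is averaged (optionally with weights) before a threshold is applied. The function name `soft_vote` and the probabilities are hypothetical.

```python
def soft_vote(probabilities, weights=None):
    """Average per-model positive-class probabilities for each sample.

    probabilities: list of equal-length lists, one list per model.
    weights: optional per-model weights; defaults to equal weighting.
    """
    n_models = len(probabilities)
    if weights is None:
        weights = [1.0] * n_models
    total_weight = sum(weights)
    n_samples = len(probabilities[0])
    return [
        sum(w * p[i] for w, p in zip(weights, probabilities)) / total_weight
        for i in range(n_samples)
    ]


# Toy probabilities from two models for three patients (illustrative only).
lr_proba = [0.20, 0.70, 0.55]
cb_proba = [0.10, 0.90, 0.35]

avg_proba = soft_vote([lr_proba, cb_proba])
print([round(p, 2) for p in avg_proba])
```

Because the averaged output is still a probability, the same recall-targeted threshold adjustment used for the single models can be applied to the ensemble.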
Final Model
After extensive experimentation, we selected the CatBoost model as our final predictor due to its superior performance across key metrics:
- Recall: 0.9048 (on validation set)
- Precision: 0.1206
- ROC AUC: 0.8621
- PR AUC: 0.1668
- F1 Score: 0.2129
- Balanced Accuracy: 0.8050
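These threshold-dependent figures can be checked directly against the validation confusion matrix reported for the retrained CatBoost model, `[[663 277] [4 38]]`. The sketch below recomputes them from the four cell counts using the standard definitions; the function name `metrics_from_confusion` is hypothetical. ROC AUC and PR AUC are not included because they depend on the full score ranking, not on a single confusion matrix.

```python
def metrics_from_confusion(tn, fp, fn, tp):
    """Compute threshold-dependent metrics from binary confusion-matrix counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)          # sensitivity on the stroke class
    specificity = tn / (tn + fp)     # recall on the no-stroke class
    f1 = 2 * precision * recall / (precision + recall)
    balanced_accuracy = (recall + specificity) / 2
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "balanced_accuracy": balanced_accuracy,
    }


# Counts taken from the CatBoost validation confusion matrix above.
metrics = metrics_from_confusion(tn=663, fp=277, fn=4, tp=38)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")
```

The low precision follows from the counts themselves: 38 true positives sit alongside 277 false positives, so roughly one flagged patient in eight actually had a stroke.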
Key Findings
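- Age-derived features dominated the SHAP ranking (bmi_age, age_squared, age_glucose, and age itself), followed by glucose-related terms.
- The model reaches the 0.90 recall target on both the validation and test sets, which is crucial for identifying potential stroke cases. On the test set this required lowering the adjusted threshold to 0.1686, and ROC AUC fell from 0.8621 to 0.8096 and precision from 0.1206 to 0.0670.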
- There's a trade-off between precision and recall due to the imbalanced nature of the dataset.
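The feature ranking referred to above is based on the mean absolute SHAP value per feature, which the notebook computes as `np.abs(shap_values).mean(0)`. As a minimal, dependency-free sketch of that aggregation, the example below applies the same calculation to a small hand-made matrix. The values and the helper name `rank_by_mean_abs` are hypothetical and are not taken from the model.

```python
def rank_by_mean_abs(shap_rows, feature_names):
    """Rank features by the mean absolute value of their per-sample attributions.

    shap_rows: list of rows, one per sample, each with one value per feature.
    """
    n_samples = len(shap_rows)
    importance = [
        sum(abs(row[j]) for row in shap_rows) / n_samples
        for j in range(len(feature_names))
    ]
    return sorted(
        zip(feature_names, importance), key=lambda pair: pair[1], reverse=True
    )


# Toy attribution matrix: 3 samples x 3 features (illustrative values only).
toy_shap = [
    [0.4, -0.1, 0.0],
    [-0.6, 0.3, 0.1],
    [0.2, -0.2, -0.1],
]
toy_features = ["age", "bmi", "hypertension"]

ranking = rank_by_mean_abs(toy_shap, toy_features)
for feature, value in ranking:
    print(f"{feature}: {value:.4f}")
```

Taking the absolute value before averaging matters: a feature that pushes some predictions up and others down would otherwise cancel out and look unimportant.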
Challenges and Future Work
- Dealing with class imbalance remained a significant challenge throughout the project.
- Future work could focus on gathering more data, especially for the minority class, to improve model performance.
- Exploring more advanced techniques like anomaly detection or semi-supervised learning could potentially yield better results.
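The class-imbalance handling used throughout the notebook weights the stroke class by the ratio of negative to positive samples (`n_negative / n_positive`). The sketch below shows what that ratio does, using the class counts reported for the validation set (940 and 42) purely as an illustration, since the training-set counts are not shown in this section. The helper name `positive_class_weight` is hypothetical.

```python
def positive_class_weight(labels):
    """Return n_negative / n_positive for a list of binary labels."""
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = sum(1 for y in labels if y == 0)
    if n_pos == 0:
        raise ValueError("labels contain no positive samples")
    return n_neg / n_pos


# Illustrative labels with the validation-set class counts (not training data).
labels = [0] * 940 + [1] * 42

weight = positive_class_weight(labels)
class_weight = {0: 1, 1: weight}

# With this weight, the total weight of the positive class equals the
# total weight of the negative class.
total_negative_weight = 940 * class_weight[0]
total_positive_weight = 42 * class_weight[1]
print(f"positive class weight: {weight:.2f}")
print(total_negative_weight, round(total_positive_weight, 6))
```

Weighting rebalances the loss but does not add information about the minority class, which is why collecting more stroke cases remains the more direct route to better precision.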
Conclusion
This project demonstrates the potential of machine learning in healthcare, particularly for risk prediction. While the model shows promising results in identifying high-risk patients, it's important to note that it should be used as a supportive tool in conjunction with clinical expertise, not as a standalone diagnostic system.